// Copyright 2019 DeepMind Technologies Ltd. All rights reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#include "open_spiel/algorithms/tabular_exploitability.h"

#include <cmath>
#include <limits>
#include <unordered_set>

#include "open_spiel/algorithms/best_response.h"
#include "open_spiel/algorithms/expected_returns.h"
#include "open_spiel/algorithms/history_tree.h"
#include "open_spiel/policy.h"
#include "open_spiel/spiel.h"
#include "open_spiel/spiel_utils.h"

namespace open_spiel {
namespace algorithms {

double Exploitability(const Game& game, const Policy& policy) {
  GameType game_type = game.GetType();
  if (game_type.dynamics != GameType::Dynamics::kSequential) {
    SpielFatalError("The game must be turn-based.");
  }
  if (game_type.utility != GameType::Utility::kZeroSum &&
      game_type.utility != GameType::Utility::kConstantSum) {
    SpielFatalError("The game must have zero- or constant-sum utility.");
  }

  std::unique_ptr<State> root = game.NewInitialState();
  double nash_conv = 0;
  for (auto i = Player{0}; i < game.NumPlayers(); ++i) {
    TabularBestResponse best_response(game, i, &policy);
    nash_conv += best_response.Value(*root);
  }
  return (nash_conv - game.UtilitySum()) / game.NumPlayers();
}

// Convenience overload: builds a TabularPolicy from the given map of
// string keys to ActionsAndProbs and delegates to the Policy-based
// Exploitability overload.
double Exploitability(
    const Game& game,
    const std::unordered_map<std::string, ActionsAndProbs>& policy) {
  // The temporary lives until the end of the full expression, i.e. for the
  // whole duration of the delegated call.
  return Exploitability(game, TabularPolicy(policy));
}

// Two-argument convenience overload: forwards to the three-argument NashConv
// with `use_state_get_policy` set to false.
double NashConv(const Game& game, const Policy& policy) {
  constexpr bool kUseStateGetPolicy = false;
  return NashConv(game, policy, kUseStateGetPolicy);
}

double NashConv(const Game& game, const Policy& policy,
                bool use_state_get_policy) {
  GameType game_type = game.GetType();
  if (game_type.dynamics != GameType::Dynamics::kSequential) {
    SpielFatalError("The game must be turn-based.");
  }

  std::unique_ptr<State> root = game.NewInitialState();
  std::vector<double> best_response_values(game.NumPlayers());
  for (auto p = Player{0}; p < game.NumPlayers(); ++p) {
    TabularBestResponse best_response(game, p, &policy);
    best_response_values[p] = best_response.Value(*root);
  }
  std::vector<double> on_policy_values =
      ExpectedReturns(*root, policy, -1, !use_state_get_policy);
  SPIEL_CHECK_EQ(best_response_values.size(), on_policy_values.size());
  double nash_conv = 0;
  for (auto p = Player{0}; p < game.NumPlayers(); ++p) {
    double deviation_incentive = best_response_values[p] - on_policy_values[p];
    if (deviation_incentive < -FloatingPointDefaultThresholdRatio()) {
      SpielFatalError(
          absl::StrCat("Negative Nash deviation incentive for player ", p, ": ",
                       deviation_incentive, ". Does you game have imperfect ",
                       "recall, or does State::ToString() not distinguish ",
                       "between unique states?"));
    }
    nash_conv += deviation_incentive;
  }
  return nash_conv;
}

// Convenience overload: builds a TabularPolicy from the given map of
// string keys to ActionsAndProbs and delegates to the two-argument,
// Policy-based NashConv overload.
double NashConv(
    const Game& game,
    const std::unordered_map<std::string, ActionsAndProbs>& policy) {
  // The temporary lives until the end of the full expression, i.e. for the
  // whole duration of the delegated call.
  return NashConv(game, TabularPolicy(policy));
}

}  // namespace algorithms
}  // namespace open_spiel
